import io
import base64
import time

import pandas as pd
import numpy as np
from PIL import Image
import opennsfw2 as n2


def prep(img_b64):
    """Turn one base64-encoded image into an OpenNSFW2 model input.

    The payload is base64-decoded, opened with PIL, and passed through
    ``n2.preprocess_image`` using the ``YAHOO`` preprocessing mode.

    Args:
        img_b64: Base64-encoded image data.

    Returns:
        The result of ``n2.preprocess_image``; callers stack these with
        ``np.stack``, so it is presumably an array — confirm against the
        opennsfw2 version in use.
    """
    raw_bytes = base64.b64decode(img_b64)
    image = Image.open(io.BytesIO(raw_bytes))
    return n2.preprocess_image(image, n2.Preprocessing.YAHOO)


def _concat_or_empty(parts, dtype):
    """Concatenate 1-D array chunks, or return an empty array of ``dtype`` if there are none."""
    return np.concatenate(parts) if parts else np.empty(0, dtype=dtype)


def inference_pipeline(
    fp_output: str,
    fp_input: str,
    batch_size: int = 1024,
    n_samples: int = None,
):
    """Score base64-encoded images from a JSON-lines file and write the scores to CSV.

    Each input line must provide ``qry_id``, ``img_rank`` and ``img_b64``
    fields. Rows are read and scored ``batch_size`` at a time with the
    OpenNSFW2 model. One CSV row ``qry_id,img_rank,nsfw_score`` is written
    per input row, where ``nsfw_score`` is column 1 of the model output.

    Args:
        fp_output: Path of the CSV file to write.
        fp_input: Path of the JSON-lines input file.
        batch_size: Number of rows read and scored per model call.
        n_samples: If given, only the first ``n_samples`` rows are read;
            ``None`` (the default) reads the whole file.
    """
    model = n2.make_open_nsfw_model()

    qry_ids, img_ranks, scores = [], [], []
    n_images = 0
    start = time.perf_counter()
    with pd.read_json(
        fp_input,
        lines=True,
        chunksize=batch_size,
        nrows=n_samples,
    ) as data_gen:
        for batch in data_gen:
            img_batch = np.stack([prep(img_b64) for img_b64 in batch["img_b64"]])
            preds = model(img_batch, training=False)
            qry_ids.append(batch["qry_id"].to_numpy())
            img_ranks.append(batch["img_rank"].to_numpy())
            # np.asarray handles both eager tensors and plain arrays.
            scores.append(np.asarray(preds)[:, 1])
            n_images += len(batch)
    # Report the count actually processed: n_samples may be None or exceed the file length.
    print(f"{n_images} images took {time.perf_counter() - start}s")

    # Build the frame column-wise so each column keeps its own dtype. A 3xN
    # frame transposed would upcast the id columns to float64 and lose
    # precision on large ids. Empty input yields an empty CSV with a header
    # instead of raising.
    results = pd.DataFrame(
        {
            "qry_id": _concat_or_empty(qry_ids, int).astype(int),
            "img_rank": _concat_or_empty(img_ranks, int).astype(int),
            "nsfw_score": _concat_or_empty(scores, float),
        }
    )
    results.to_csv(fp_output, index=False)


if __name__ == "__main__":
    # The original called inference_pipeline() with no arguments, which always
    # raised TypeError (fp_output and fp_input are required). Take them from
    # the command line instead.
    import argparse

    parser = argparse.ArgumentParser(
        description="Score base64-encoded images with OpenNSFW2 and write a CSV."
    )
    parser.add_argument("fp_output", help="Path of the output CSV file.")
    parser.add_argument("fp_input", help="Path of the JSON-lines input file.")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1024,
        help="Number of rows scored per model call (default: 1024).",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=None,
        help="Only read the first N rows (default: read the whole file).",
    )
    args = parser.parse_args()
    inference_pipeline(
        args.fp_output,
        args.fp_input,
        batch_size=args.batch_size,
        n_samples=args.n_samples,
    )
